""" DenseHead. """
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops as P

class DenseHead(nn.Cell):
    """
    DenseHead architecture: a single bias-free fully-connected layer.

    The input is flattened to ``(-1, in_channel)`` and projected to
    ``n_class`` outputs. When ``model_size == '2.0x'``, dropout
    (``keep_prob=0.8``) is applied to the input before flattening.

    Args:
        n_class (int): Number of output classes.
        in_channel (int): Input feature size. The total number of elements
            in the input must be divisible by it, or the reshape fails.
        model_size (str): Model width tag. Only ``'2.0x'`` changes behavior
            (it enables dropout). Default: ``"1.0x"``.

    Inputs:
        - **x** (Tensor) - Features to classify. Presumably
          ``(batch, in_channel)`` or ``(batch, in_channel, 1, 1)`` coming
          from a pooled backbone -- TODO confirm against the caller.

    Returns:
        Tensor of shape ``(-1, n_class)`` with raw scores (no softmax is
        applied here).
    """
    def __init__(self, n_class, in_channel, model_size="1.0x"):
        super(DenseHead, self).__init__()

        self.model_size = model_size
        self.in_channels = in_channel
        self.out_channels = n_class

        # Dropout exists only for the widest variant; construct() checks the
        # same condition before using it.
        if self.model_size == '2.0x':
            self.dropout = nn.Dropout(keep_prob=0.8)

        # Create the primitive once instead of on every construct() call.
        self.reshape = P.Reshape()

        head = [nn.Dense(in_channels=in_channel, out_channels=n_class, has_bias=False)]
        self.head = nn.SequentialCell(head)
        self._initialize_weights()

    def construct(self, x):
        """Optionally apply dropout, flatten to ``(-1, in_channels)``, then project."""
        if self.model_size == '2.0x':
            x = self.dropout(x)
        x = self.reshape(x, (-1, self.in_channels))
        x = self.head(x)
        return x

    def _initialize_weights(self):
        """
        Re-initialize the weights of sub-cells in place with random normal values.

        - ``nn.Dense``: N(0, 0.01).
        - ``nn.Conv2d`` whose name contains ``'first'``: N(0, 0.01).
        - Any other ``nn.Conv2d``: N(0, 1 / weight.shape[1]).

        This head only creates a bias-free ``nn.Dense``, so the Conv2d branches
        are not exercised by it. No seed is set, so results depend on numpy's
        global RNG state.

        Returns:
            None.
        """
        for name, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    m.weight.set_data(Tensor(np.random.normal(0, 0.01,
                                                              m.weight.data.shape).astype("float32")))
                else:
                    m.weight.set_data(Tensor(np.random.normal(0, 1.0 / m.weight.data.shape[1],
                                                              m.weight.data.shape).astype("float32")))

            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))

    @property
    def get_head(self):
        """Return the ``nn.SequentialCell`` that wraps the dense layer (``self.head``)."""
        return self.head
